57914b
@@ -39,7 +39,6 @@
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.ErrorMsg;
-import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
 import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.TableScanOperator;
@@ -62,10 +61,8 @@
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
-import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
 import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
 
 /**
  *this transformation does bucket map join optimization.
@@ -182,7 +179,7 @@
private boolean convertBucketMapJoin(Node nd, Stack<Node> stack, NodeProcessorCt
         }
       }
 
-      MapJoinDesc mjDecs = mapJoinOp.getConf();
+      MapJoinDesc mjDesc = mapJoinOp.getConf();
       LinkedHashMap<String, List<Integer>> aliasToPartitionBucketNumberMapping =
           new LinkedHashMap<String, List<Integer>>();
       LinkedHashMap<String, List<List<String>>> aliasToPartitionBucketFileNamesMapping =
@@ -197,6 +194,7 @@
private boolean convertBucketMapJoin(Node nd, Stack<Node> stack, NodeProcessorCt
       LinkedHashMap<Partition, List<String>> bigTblPartsToBucketFileNames = new LinkedHashMap<Partition, List<String>>();
       LinkedHashMap<Partition, Integer> bigTblPartsToBucketNumber = new LinkedHashMap<Partition, Integer>();
 
+      Integer[] orders = null; // bucket-column index of each join key; must agree across aliases
       boolean bigTablePartitioned = true;
       for (int index = 0; index < joinAliases.size(); index++) {
         String alias = joinAliases.get(index);
@@ -204,6 +202,14 @@
private boolean convertBucketMapJoin(Node nd, Stack<Node> stack, NodeProcessorCt
         if (tso == null) {
           return false;
         }
+        List<String> keys = toColumns(mjDesc.getKeys().get((byte) index));
+        if (keys == null || keys.isEmpty()) {
+          return false;
+        }
+        if (orders == null) {
+          orders = new Integer[keys.size()];
+        }
+
         Table tbl = topToTable.get(tso);
         if(tbl.isPartitioned()) {
           PrunedPartitionList prunedParts;
@@ -231,7 +237,7 @@
private boolean convertBucketMapJoin(Node nd, Stack<Node> stack, NodeProcessorCt
             List<Integer> buckets = new ArrayList<Integer>();
             List<List<String>> files = new ArrayList<List<String>>();
             for (Partition p : partitions) {
-              if (!checkBucketColumns(p.getBucketCols(), mjDecs, index)) {
+              if (!checkBucketColumns(p.getBucketCols(), keys, orders)) {
                 return false;
               }
               List<String> fileNames = getOnePartitionBucketFileNames(p.getDataLocation());
@@ -258,7 +264,7 @@
private boolean convertBucketMapJoin(Node nd, Stack<Node> stack, NodeProcessorCt
             }
           }
         } else {
-          if (!checkBucketColumns(tbl.getBucketCols(), mjDecs, index)) {
+          if (!checkBucketColumns(tbl.getBucketCols(), keys, orders)) {
             return false;
           }
           List<String> fileNames = getOnePartitionBucketFileNames(tbl.getDataLocation());
@@ -360,6 +366,17 @@
public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
       return null;
     }
 
+    private List<String> toColumns(List<ExprNodeDesc> keys) { // join key exprs -> column names
+      List<String> columns = new ArrayList<String>();
+      for (ExprNodeDesc key : keys) {
+        if (!(key instanceof ExprNodeColumnDesc)) {
+          return null; // any key that is not a plain column reference makes the result null
+        }
+        columns.add(((ExprNodeColumnDesc) key).getColumn());
+      }
+      return columns; // names are in the same order as the incoming keys
+    }
+
     // convert partition to partition spec string
     private Map<String, List<String>> convert(Map<Partition, List<String>> mapping) {
       Map<String, List<String>> converted = new HashMap<String, List<String>>();
@@ -433,42 +450,23 @@
private boolean checkBucketNumberAgainstBigTable(
       return fileNames;
     }
 
-    private boolean checkBucketColumns(List<String> bucketColumns, MapJoinDesc mjDesc, int index) {
-      List<ExprNodeDesc> keys = mjDesc.getKeys().get((byte)index);
-      if (keys == null || bucketColumns == null || bucketColumns.size() == 0) {
+    private boolean checkBucketColumns(List<String> bucketColumns, List<String> keys,
+        Integer[] orders) { // orders[i]: bucket-column index recorded for keys.get(i), or null
+      if (keys == null || bucketColumns == null || bucketColumns.isEmpty()) {
         return false;
       }
-
-      //get all join columns from join keys stored in MapJoinDesc
-      List<String> joinCols = new ArrayList<String>();
-      List<ExprNodeDesc> joinKeys = new ArrayList<ExprNodeDesc>();
-      joinKeys.addAll(keys);
-      while (joinKeys.size() > 0) {
-        ExprNodeDesc node = joinKeys.remove(0);
-        if (node instanceof ExprNodeColumnDesc) {
-          joinCols.addAll(node.getCols());
-        } else if (node instanceof ExprNodeGenericFuncDesc) {
-          ExprNodeGenericFuncDesc udfNode = ((ExprNodeGenericFuncDesc) node);
-          GenericUDF udf = udfNode.getGenericUDF();
-          if (!FunctionRegistry.isDeterministic(udf)) {
-            return false;
-          }
-          joinKeys.addAll(0, udfNode.getChildExprs());
-        } else {
+      for (int i = 0; i < keys.size(); i++) {
+        int index = bucketColumns.indexOf(keys.get(i)); // -1 if the key is not a bucket column
+        if (orders[i] != null && orders[i] != index) { // mismatch with the index already recorded
           return false;
         }
+        orders[i] = index; // stored in the caller's array so later calls compare against it
       }
-
       // Check if the join columns contains all bucket columns.
       // If a table is bucketized on column B, but the join key is A and B,
       // it is easy to see joining on different buckets yield empty results.
-      if (joinCols.size() == 0 || !joinCols.containsAll(bucketColumns)) {
-        return false;
-      }
-
-      return true;
+      return keys.containsAll(bucketColumns); // true only when every bucket column is a join key
     }
-
   }
 
   class BucketMapjoinOptProcCtx implements NodeProcessorCtx {
